"""
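Build the locations file for one country from its gridded population data.

This script reads two input files:
    - a shapefile of the country's state boundaries (used only for plotting), and
    - a population raster (.tif).
It aggregates the raster into 10 km grid cells (10 x 10 blocks of the 1 km input
pixels), flags the cells with more than POPULATION_THRESHOLD inhabitants (e.g. 100),
saves a diagnostic plot, and writes the flagged cells to a comma-separated text
file (no header row) with the following columns:
    - lat
    - lon
    - from  (the configured START_DATETIME)
    - to    (the configured END_DATETIME)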
"""

# Import necessary libraries
import os, time, sys, re
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import geopandas as gpd
import rasterio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from utils.config import load_config
from utils.logger import set_logger

# Load the configuration file and set up logging
config = load_config()
logger = set_logger()

# Make relative paths resolve against this script's directory. abspath() is
# required: __file__ can be a bare filename (e.g. `python script.py` before
# Python 3.9), for which dirname() returns '' and os.chdir('') raises
# FileNotFoundError.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# CONSTANTS (read from the config; trailing comments show example values)
POPULATION_THRESHOLD = config['pop_threshold']   # 100
START_DATETIME       = config['start_datetime']  # '2020-01-01T00:00:00Z'
END_DATETIME         = config['end_datetime']    # '2022-12-31T23:59:59Z'
ESPG                 = config['espg']            # 4326 (an EPSG code; name kept for compatibility)

COUNTRY_NAME = config['country_name']  # e.g. "mexico"
COUNTRY_CODE = config['country_code']  # e.g. "mx"

# Convert the start and end datetimes to the format `YYYYMM`
# (keep the first 7 chars, e.g. '2020-01', then strip non-digits -> '202001')
START_YEAR_MONTH = re.sub(r'[^0-9]', '', START_DATETIME[:7])
END_YEAR_MONTH = re.sub(r'[^0-9]', '', END_DATETIME[:7])

# PATHS
# NOTE: PROJECT_ROOT is this script's own directory; data/output live three levels up.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
INPUT_DIR     = os.path.join(PROJECT_ROOT, '../../../data', COUNTRY_NAME)
OUTPUT_DIR    = os.path.join(PROJECT_ROOT, '../../../output', COUNTRY_NAME)
INPUT_SHP_FILE  = os.path.join(INPUT_DIR, f'{COUNTRY_CODE}-query-setup/state-shape/{COUNTRY_CODE}-states.shp')  # f"data/{COUNTRY_NAME}/{COUNTRY_CODE}-query-setup/state-shape/{COUNTRY_CODE}-states.shp"
INPUT_TIF_FILE = os.path.join(INPUT_DIR, f'{COUNTRY_CODE}-query-setup/pop-grid/{COUNTRY_CODE}_pd_2020_1km_Aggregated.tif')  # f"data/{COUNTRY_NAME}/{COUNTRY_CODE}-query-setup/pop-grid/{COUNTRY_CODE}_pd_2020_1km_Aggregated.tif"
OUTPUT_TXT_FILE  = os.path.join(OUTPUT_DIR, f"{COUNTRY_CODE}-query-setup", f"{START_YEAR_MONTH}_{END_YEAR_MONTH}_{COUNTRY_CODE}-locations.txt")  # f"output/{COUNTRY_NAME}/{COUNTRY_CODE}-query-setup/{START_YEAR_MONTH}_{END_YEAR_MONTH}_{COUNTRY_CODE}-locations.txt"
PLOT_PNG_FILE = os.path.join(OUTPUT_DIR, "population_grid.png")

def load_shapefile(filepath):
    """Read a vector file with GeoPandas and reproject it to EPSG ``ESPG``.

    Parameters
    ----------
    filepath : str
        Path to a vector file readable by ``geopandas.read_file``.

    Returns
    -------
    geopandas.GeoDataFrame or None
        The reprojected layer, or ``None`` if reading or reprojecting raised
        an ``Exception`` (the error is logged).
    """
    logger.info(f"Loading shapefile from {filepath}...")
    try:
        polygon_gdf = gpd.read_file(filepath)
        polygon_gdf = polygon_gdf.to_crs(epsg=ESPG)  # e.g. EPSG:4326
    except Exception as e:
        # Broad catch on purpose: main() treats None as "abort with exit code 1".
        logger.error(f"Error loading shapefile: {e}!")
        return None
    # Returning outside `finally` lets non-Exception errors (e.g. KeyboardInterrupt)
    # propagate instead of being swallowed together with a half-processed result.
    logger.info("Shapefile loaded successfully.")
    return polygon_gdf

def load_transform_raster(filepath, factor=10):
    """Read band 1 of a raster and sum it over ``factor`` x ``factor`` pixel blocks.

    Rows/columns at the bottom/right that do not fill a whole block are dropped.
    Each block is reported at the map coordinates of its centre.

    Parameters
    ----------
    filepath : str
        Path to a raster readable by ``rasterio``.
    factor : int, default 10
        Block edge length in input pixels (with a 1 km input grid the default
        gives ~10 km cells).

    Returns
    -------
    pandas.DataFrame or None
        Columns ``x``, ``y`` (block-centre coordinates in the raster's own CRS,
        rounded to 4 decimals) and ``value`` (block sum), one row per block in
        row-major order; ``None`` if reading or aggregating failed (logged).
    """
    logger.info(f"Loading raster data from {filepath}...")
    try:
        with rasterio.open(filepath) as src:
            # masked=True masks pixels flagged as nodata; filling them with 0
            # keeps nodata sentinel values out of the block sums, so a block that
            # is only partly covered by valid pixels still gets a correct total.
            r = src.read(1, masked=True).filled(0)
            transform = src.transform

        height, width = r.shape
        n_rows, n_cols = height // factor, width // factor

        # Trim to a whole number of blocks, then sum each factor x factor block.
        r_trimmed = r[:n_rows * factor, :n_cols * factor]
        aggregated_data = r_trimmed.reshape((n_rows, factor, n_cols, factor)).sum(axis=(1, 3))

        # Pixel-space centre of each block: block j spans pixel columns
        # j*factor .. (j+1)*factor - 1, so its centre sits at j*factor + factor/2.
        col_centres = np.arange(n_cols) * factor + factor / 2
        row_centres = np.arange(n_rows) * factor + factor / 2
        cols, rows = np.meshgrid(col_centres, row_centres)

        # Map pixel coordinates into the raster CRS with the affine transform
        # (x = a*col + b*row + c, y = d*col + e*row + f).
        # NOTE(review): downstream these are written as lon/lat, which assumes
        # the raster is in a geographic CRS such as EPSG:4326 — confirm.
        xs = transform.a * cols + transform.b * rows + transform.c
        ys = transform.d * cols + transform.e * rows + transform.f

        r_df = pd.DataFrame({'x': xs.ravel(), 'y': ys.ravel(), 'value': aggregated_data.ravel()})
        r_df['x'] = r_df['x'].round(4)
        r_df['y'] = r_df['y'].round(4)
    except Exception as e:
        logger.error(f"Error loading raster data: {e}!")
        return None
    logger.info("Raster data loaded successfully.")
    return r_df

def filter_population(r_df, threshold=100):
    """Flag grid cells whose aggregated population exceeds ``threshold``.

    Parameters
    ----------
    r_df : pandas.DataFrame
        Must contain a numeric ``value`` column.
    threshold : float, default 100
        Cells with ``value`` strictly greater than this are flagged.

    Returns
    -------
    pandas.DataFrame
        A copy of ``r_df`` with a boolean ``include`` column added (NaN values
        compare as ``False``). The caller's DataFrame is left unmodified.
    """
    return r_df.assign(include=r_df['value'] > threshold)

def plot_data(r_df, polygon_gdf):
    """Save a two-panel map of the aggregated population grid to ``PLOT_PNG_FILE``.

    The left panel shows every grid cell coloured by sqrt(population). The right
    panel shows cells coloured by their boolean ``include`` flag. Both panels
    overlay the boundaries of ``polygon_gdf``.

    Parameters
    ----------
    r_df : pandas.DataFrame
        Columns ``x``, ``y``, ``value`` and boolean ``include``.
    polygon_gdf : geopandas.GeoDataFrame
        Boundaries drawn on top of the grid.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Population grid; sqrt compresses the colour range
        sns.scatterplot(data=r_df, x='x', y='y', hue=np.sqrt(r_df['value']),
                        palette='Spectral', ax=axes[0], marker="s", edgecolor=None, s=20)
        polygon_gdf.boundary.plot(ax=axes[0], color='black', linewidth=0.5)
        axes[0].set_title(f"10 km population grid of {COUNTRY_NAME}\n(2020, n = {len(r_df):,})")

        # Cells above POPULATION_THRESHOLD (the `include` flag)
        sns.scatterplot(data=r_df, x='x', y='y', hue='include',
                        palette='Set2', ax=axes[1], marker="s", edgecolor=None, s=20)
        polygon_gdf.boundary.plot(ax=axes[1], color='black', linewidth=0.5)
        axes[1].set_title(f"Grids with more than {POPULATION_THRESHOLD} inhabitants\n(n = {r_df['include'].sum():,})")

        plt.tight_layout()
        # savefig does not create missing directories, and OUTPUT_DIR may not
        # exist yet on a first run.
        os.makedirs(os.path.dirname(PLOT_PNG_FILE), exist_ok=True)
        fig.savefig(PLOT_PNG_FILE)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

def write_locations_file(r_df, output_filepath):
    """Write the cells flagged ``include`` as a headerless CSV of locations.

    Each output row is ``lat,lon,from,to``. Here lat/lon are the cell's ``y``/``x``
    values, and from/to are START_DATETIME/END_DATETIME.

    Parameters
    ----------
    r_df : pandas.DataFrame
        Columns ``x``, ``y`` and boolean ``include``.
    output_filepath : str
        Destination path; missing parent directories are created.

    Errors while writing are logged, not raised.
    """
    filtered_df = r_df[r_df['include']]
    # NOTE(review): assumes x/y are geographic lon/lat coordinates — confirm.
    locations_df = pd.DataFrame({
        'lat': filtered_df['y'],
        'lon': filtered_df['x'],
        'from': START_DATETIME,
        'to': END_DATETIME
    })
    logger.info(f"Number of locations: {len(locations_df)}")
    logger.info(f"Writing locations to {output_filepath}...")
    try:
        output_dir = os.path.dirname(output_filepath)
        # A bare filename has no directory part, and os.makedirs('') would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        locations_df.to_csv(output_filepath, index=False, header=False)
        logger.info("Locations written successfully.")
    except Exception as e:
        logger.error(f"Error writing locations: {e}!")

def main():
    """Run the pipeline: load inputs, flag populated cells, write locations, plot.

    Exits the process with status 1 if either input file cannot be loaded.
    """
    # Load shapefile
    polygon_gdf = load_shapefile(INPUT_SHP_FILE)
    if polygon_gdf is None:
        logger.error("Error loading shapefile!")
        sys.exit(1)

    # Load and aggregate raster data
    r_df = load_transform_raster(INPUT_TIF_FILE)
    if r_df is None:
        logger.error("Error loading raster data!")
        sys.exit(1)

    # Flag cells above the population threshold (always returns a DataFrame)
    r_df = filter_population(r_df, threshold=POPULATION_THRESHOLD)

    # Write the primary output before plotting. A plotting failure then cannot
    # prevent the locations file from being produced, and writing it first also
    # creates OUTPUT_DIR, which the plot is saved into.
    write_locations_file(r_df, OUTPUT_TXT_FILE)

    # Plot data
    plot_data(r_df, polygon_gdf)

if __name__ == "__main__":
    # perf_counter is monotonic, so the measured duration is immune to
    # system-clock adjustments (unlike time.time).
    start = time.perf_counter()
    try:
        main()
    finally:
        # Report timing even when main() exits early via sys.exit(1) or raises.
        end = time.perf_counter()
        logger.info(f"Execution time: {end - start:.2f} seconds")
